文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
0. 测试环境
Python 3.6.9, Pytorch 1.5.0
1. 基本概念
Tensor
是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32
。
- 示例一
1 | >>> a = torch.tensor([1.0]) |
Tensor
中只有一个数字时,使用torch.Tensor.item()
可以得到一个Python数字。requires_grad
为True
时,表示需要计算Tensor
的梯度。requires_grad=False
可以用来冻结部分网络,只更新另一部分网络的参数。
- 示例二
1 | >>> a = torch.tensor([1.0, 2.0]) |
a.data
返回的是一个新的Tensor
对象b
,a, b
的id
不同,说明二者不是同一个Tensor
,但b
与a
共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b
的元素时,a
的元素也对应修改。
2. requiresgrad()与detach()
1 | >>> a = torch.tensor([1.0, 2.0]) |
requires_grad_()
requires_grad_()
函数会改变Tensor
的requires_grad
属性并返回Tensor
,修改requires_grad
的操作是原位操作(in place)。其默认参数为requires_grad=True
。requires_grad=True
时,自动求导会记录对Tensor
的操作,requires_grad_()
的主要用途是告诉自动求导开始记录对Tensor
的操作。
detach()
detach()
函数会返回一个新的Tensor
对象b
,并且新Tensor
是与当前的计算图分离的,其requires_grad
属性为False
,反向传播时不会计算其梯度。b
与a
共享数据的存储空间,二者指向同一块内存。
注:共享内存空间只是共享的数据部分,a.grad
与b.grad
是不同的。
3. torch.no_grad()
torch.no_grad()
是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
1 | 1.0, 2.0], requires_grad=True) a = torch.tensor([ |
上面的例子中,当a
的requires_grad=True
时,不使用torch.no_grad()
,c.requires_grad
为True
,使用torch.no_grad()
时,b.requires_grad
为False
,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True
会占用更多的计算资源及存储资源。
4. 总结
requires_grad_()
会修改Tensor
的requires_grad
属性。
detach()
会返回一个与计算图分离的新Tensor
,新Tensor
不会在反向传播中计算梯度,会在特定场合使用。
torch.no_grad()
更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。